package facewap;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Arrays;

/**
 * Resumes adversarial training of a generator / discriminator pair on MNIST and periodically
 * shows generated samples in a Swing window.
 *
 * <p>Three networks are used: {@code gen} (generator), {@code dis} (discriminator) and
 * {@code gan}. The first {@code gen.getLayers().length} layers of {@code gan} correspond to the
 * generator and its remaining layers to the discriminator. Parameters are copied between the
 * networks by {@link #copyParams}, {@link #updateGan} and {@link #updateGen}. All three models
 * must already be saved at the paths below; they are re-saved every 100 iterations.
 */
public class FaceGan {
    /** Path of the saved generator network. */
    static String genModel = "F:/face/gen1.zip";
    /** Path of the saved discriminator network. */
    static String disModel = "F:/face/dis1.zip";
    /** Path of the saved combined generator + discriminator network. */
    static String ganModel = "F:/face/gan1.zip";

    /** Number of training iterations; each consumes one MNIST mini-batch. */
    private static final int ITERATIONS = 100_000;
    /** MNIST mini-batch size. */
    private static final int BATCH_SIZE = 128;
    /** Width of the noise vector fed to the generator. */
    private static final int NOISE_SIZE = 100;
    /** Maximum number of generated samples shown per visualization. */
    private static final int NUM_SAMPLES = 9;
    /** Side length, in pixels, of a square MNIST image. */
    private static final int IMAGE_SIDE = 28;
    /** Number of pixels in one image. */
    private static final int IMAGE_PIXELS = IMAGE_SIDE * IMAGE_SIDE;
    /** Integer magnification applied when displaying each image. */
    private static final int DISPLAY_SCALE = 8;

    private static JFrame frame;
    private static JPanel panel;

    /**
     * Loads the three saved networks, then alternates discriminator and generator updates.
     *
     * @param args ignored
     * @throws FileNotFoundException if any of the three model files does not exist
     * @throws Exception on any other loading, training or saving failure
     */
    public static void main(String... args) throws Exception {

        Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

        MnistDataSetIterator trainData = new MnistDataSetIterator(BATCH_SIZE, true, GanModel.seed);

        // Building fresh networks is not wired up here, so every model must already exist on
        // disk. Fail with a descriptive message rather than a NullPointerException at init().
        MultiLayerNetwork gen = loadModel(genModel, "generator");
        MultiLayerNetwork dis = loadModel(disModel, "discriminator");
        MultiLayerNetwork gan = loadModel(ganModel, "GAN");

        gen.init();
        dis.init();
        gan.init();

        System.out.println(gen.summary());
        System.out.println(dis.summary());
        System.out.println(gan.summary());

        // gan's parameters are taken as the source of truth when resuming.
        copyParams(gen, dis, gan);

        gen.setListeners(new PerformanceListener(10, true));
        dis.setListeners(new PerformanceListener(10, true));
        gan.setListeners(new PerformanceListener(10, true));

        trainData.reset();

        // The generator's layer count never changes, so compute it once.
        int genLayerCount = gen.getLayers().length;

        for (int i = 1; i <= ITERATIONS; i++) {
            if (!trainData.hasNext()) {
                trainData.reset();
            }
            // Map real pixels with x * 2 - 1 (to [-1, 1], assuming the iterator yields [0, 1]).
            INDArray real = trainData.next().getFeatures().muli(2).subi(1);
            int batchSize = (int) real.shape()[0];

            // Generate fakes by running noise through the generator part of gan.
            INDArray fakeIn = Nd4j.rand(batchSize, NOISE_SIZE);
            INDArray fake = gan.activateSelectedLayers(0, genLayerCount - 1, fakeIn);

            // Discriminator labels: real = 0, fake = 1.
            DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1));
            DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1));

            DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet));

            dis.fit(data);
            dis.fit(data);

            // Copy the freshly trained discriminator into the GAN network.
            updateGan(gen, dis, gan);

            // Fit gan on fresh noise labelled as "real" (0).
            gan.fit(new DataSet(Nd4j.rand(batchSize, NOISE_SIZE), Nd4j.zeros(batchSize, 1)));

            if (i % 10 == 1) {
                System.out.println("Iteration " + i/10+ " Visualizing...");
                visualize(sampleGenerator(gan, genLayerCount, fakeIn));
            }
            // Copy the GAN's generator layers back into gen.
            updateGen(gen, gan);
            if (i % 100 == 0) {
                gen.save(new File(genModel), true);
                dis.save(new File(disModel), true);
                gan.save(new File(ganModel), true);
            }
        }
    }

    /**
     * Loads a saved network from {@code path}.
     *
     * @param path file path of the saved model
     * @param role human-readable role of the model, used only in the error message
     * @return the loaded network (the {@code true} flag asks DL4J to restore updater state too)
     * @throws FileNotFoundException if {@code path} does not exist
     * @throws IOException if the file exists but cannot be read as a model
     */
    private static MultiLayerNetwork loadModel(String path, String role) throws IOException {
        File file = new File(path);
        if (!file.exists()) {
            throw new FileNotFoundException(
                    "Saved " + role + " model not found: " + file.getAbsolutePath());
        }
        return MultiLayerNetwork.load(file, true);
    }

    /**
     * Runs up to {@link #NUM_SAMPLES} rows of {@code noise} through the generator part of
     * {@code gan}, one row at a time.
     *
     * @param gan combined network whose first {@code genLayerCount} layers are the generator
     * @param genLayerCount number of generator layers at the front of {@code gan}
     * @param noise noise matrix, one row per example
     * @return one generator activation per sampled row; fewer than {@link #NUM_SAMPLES} when
     *     {@code noise} has fewer rows
     */
    private static INDArray[] sampleGenerator(MultiLayerNetwork gan, int genLayerCount, INDArray noise) {
        int rows = (int) noise.size(0);
        int count = Math.min(NUM_SAMPLES, rows);
        INDArray[] samples = new INDArray[count];
        // Wrapping in a DataSet is just a convenient way to pull out single example rows.
        DataSet noiseSet = new DataSet(noise, Nd4j.ones(rows, 1));
        for (int k = 0; k < count; k++) {
            INDArray input = noiseSet.get(k).getFeatures();
            samples[k] = gan.activateSelectedLayers(0, genLayerCount - 1, input);
        }
        return samples;
    }

    /**
     * Overwrites gen's and dis's parameters with the corresponding layers of gan.
     *
     * <p>Layers {@code [0, genLayerCount)} of gan go to gen; the remaining layers go to dis.
     */
    private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = 0; i < gan.getLayers().length; i++) {
            if (i < genLayerCount) {
                gen.getLayer(i).setParams(gan.getLayer(i).params());
            } else {
                dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params());
            }
        }
    }

    /** Copies the generator layers of gan (its first {@code gen.getLayers().length}) into gen. */
    private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) {
        for (int i = 0; i < gen.getLayers().length; i++) {
            gen.getLayer(i).setParams(gan.getLayer(i).params());
        }
    }

    /** Copies dis's parameters into the discriminator layers of gan (those after the generator). */
    private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) {
        int genLayerCount = gen.getLayers().length;
        for (int i = genLayerCount; i < gan.getLayers().length; i++) {
            gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params());
        }
    }

    /**
     * Shows {@code samples} as a 3-column image grid, creating the window on first use and
     * replacing its contents on later calls.
     */
    private static void visualize(INDArray[] samples) {
        // NOTE(review): Swing components are created/updated from the training thread rather
        // than the Event Dispatch Thread.
        if (frame == null) {
            frame = new JFrame();
            frame.setTitle("Viz");
            frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
            frame.setLayout(new BorderLayout());

            panel = new JPanel();
            // Fixed 3 columns; GridLayout adds as many rows as needed (3x3 for 9 samples).
            panel.setLayout(new GridLayout(0, 3, 8, 8));
            frame.add(panel, BorderLayout.CENTER);
            frame.setVisible(true);
        }

        panel.removeAll();

        for (INDArray sample : samples) {
            panel.add(getImage(sample));
        }

        frame.revalidate();
        frame.pack();
    }

    /**
     * Renders the first {@link #IMAGE_PIXELS} values of {@code tensor} as a grayscale
     * {@link #IMAGE_SIDE}x{@link #IMAGE_SIDE} image, magnified by {@link #DISPLAY_SCALE}.
     *
     * <p>Values are mapped back from [-1, 1] to [0, 255], the inverse of the {@code x * 2 - 1}
     * scaling applied to real images; out-of-range values are clamped.
     * NOTE(review): assumes the generator output has at least 784 elements in row-major order.
     */
    private static JLabel getImage(INDArray tensor) {
        BufferedImage bi = new BufferedImage(IMAGE_SIDE, IMAGE_SIDE, BufferedImage.TYPE_BYTE_GRAY);
        for (int i = 0; i < IMAGE_PIXELS; i++) {
            double unit = (tensor.getDouble(i) + 1.0) / 2.0;
            int pixel = (int) Math.round(Math.max(0.0, Math.min(1.0, unit)) * 255.0);
            bi.getRaster().setSample(i % IMAGE_SIDE, i / IMAGE_SIDE, 0, pixel);
        }
        ImageIcon orig = new ImageIcon(bi);
        int displaySize = DISPLAY_SCALE * IMAGE_SIDE;
        Image imageScaled = orig.getImage().getScaledInstance(displaySize, displaySize, Image.SCALE_REPLICATE);

        return new JLabel(new ImageIcon(imageScaled));
    }
}
